{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备二阶段训练数据集\n",
    "### 1. 用一阶段模型把所有query转成向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from config import *\n",
    "from task_sentence_embedding_FinanceFAQ_step1_1 import model\n",
    "\n",
    "# 读取标问和所有语料\n",
    "q_std_list = pd.read_csv(q_std_file, sep=\"\\t\", names=['c']).c.tolist()\n",
    "q_corpus = pd.read_csv(q_corpus_file, sep=\"\\t\", names=['c']).c.tolist()\n",
    "\n",
    "# get embeddings\n",
    "q_std_sentence_embeddings = model.encode(q_std_list)\n",
    "np.save(fst_q_std_vectors_file, q_std_sentence_embeddings.numpy())\n",
    "q_corpus_sentence_embeddings = model.encode(q_corpus)\n",
    "np.save(fst_q_corpus_vectors_file, q_corpus_sentence_embeddings.numpy())\n",
    "print('标准问向量路径：', fst_q_std_vectors_file)\n",
    "print('所有语料保存向量路径：', fst_q_corpus_vectors_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 为每个q_sim找到topK的的q_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from task_sentence_embedding_FinanceFAQ_step1_1 import model\n",
    "from config import *\n",
    "from utils import *\n",
    "\n",
    "# 读取q_std、q_corpus语料和向量\n",
    "q_std_list, q_std_sentence_embeddings, q_all, q_all_sentence_embeddings_dict = read_q_std_q_corpus(q_std_file, fst_q_std_vectors_file, q_corpus_file, fst_q_corpus_vectors_file)\n",
    "\n",
    "print('----加载一阶段训练(标问-相似问)数据集', fst_train_file)\n",
    "df_eval = pd.read_csv(fst_train_file, sep=\"\\t\")\n",
    "print(\"shape: \", df_eval.shape)\n",
    "df_eval = df_eval[df_eval.q_std.isin(q_std_list)]\n",
    "print(\"shape: \", df_eval.shape)\n",
    "\n",
    "df_eval = cal_performance(model, q_all_sentence_embeddings_dict, q_std_sentence_embeddings, q_std_list, df_eval, K=20)\n",
    "df_eval.to_csv(fst_std_data_results, index=None, sep=\"\\t\")\n",
    "df_eval.iloc[3:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 二阶段正负样本生成\n",
    "预测的topK中和q_std一致的为正样本，不一致的为困难负样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xdf = df_eval.copy(deep=True)\n",
    "# xdf['q_std_pred_list']=xdf.q_std_pred_list.apply(lambda v:eval(v))\n",
    "print('预测结果中和q_std不一致的'.center(60, '-'))\n",
    "xdf['q_std_pred_list_else'] = xdf.apply(lambda row: [v for v in row['q_std_pred_list'] if v[0] != row['q_std']], axis=1)\n",
    "xdf['q_std_pred_list_else_v1'] = xdf.q_std_pred_list_else.apply(lambda v: [m[0] for m in v])  # 负样本的文本\n",
    "xdf['q_std_pred_list_else_v2'] = xdf.q_std_pred_list_else.apply(lambda v: [m[1] for m in v])  # 负样本的概率\n",
    "\n",
    "print('组织正负样本'.center(60, '-'))\n",
    "xdf['pairs'] = xdf.apply(lambda row: ['1' + '\\t' + row['q_sim'] + '\\t' + row['q_std'] + '\\t' + '1'] + [\n",
    "    '0' + '\\t' + row['q_sim'] + '\\t' + v[0] + '\\t' + str(v[1]) for v in row['q_std_pred_list_else'][0:10]], axis=1)\n",
    "print(xdf.iloc[3]['pairs'])\n",
    "\n",
    "print('单独处理正负样本'.center(60, '-'))\n",
    "q_sim_list = xdf.q_sim.unique().tolist()\n",
    "q_std_list = xdf.q_std.unique().tolist()\n",
    "q_sim_dict = {q_sim_list[i]: i for i in range(0, len(q_sim_list))}\n",
    "q_std_dict = {q_std_list[i]: i for i in range(0, len(q_std_list))}\n",
    "pairs = xdf.pairs.tolist()\n",
    "pairs_list = [v.split('\\t') for vlist in pairs for v in vlist]\n",
    "pairs_df = pd.DataFrame(pairs_list, columns=['label', 'q_sim', 'q_std', 'prob'])\n",
    "print(pairs_df.drop_duplicates(['q_std', 'q_sim']).shape)\n",
    "pairs_df.head()\n",
    "\n",
    "pairs_df_2 = pairs_df.sort_values('label', ascending=0).drop_duplicates(['q_sim', 'q_std'])\n",
    "pairs_df_final = pairs_df_2\n",
    "print(pairs_df_final.shape, pairs_df.shape)\n",
    "\n",
    "print('对于每一个q_sim，仅保留概率最高的10条样本'.center(60, '-'))\n",
    "pairs_df_final['prob'] = pairs_df_final.prob.astype(\"float\")\n",
    "pairs_df_final['nrank'] = pairs_df_final.groupby(['label', 'q_sim'])['prob'].rank(ascending=0, method='first')\n",
    "df_final = pairs_df_final[pairs_df_final.nrank <= 9].reset_index(drop=True)\n",
    "df_final['sim_idx'] = df_final.q_sim.map(q_sim_dict)\n",
    "df_final['std_idx'] = df_final.q_std.map(q_std_dict)\n",
    "df_final = df_final.sort_values(['sim_idx', 'label', 'nrank'], ascending=[1, 0, 1])[['label', 'q_sim', 'q_std']].reset_index(drop=True)\n",
    "\n",
    "print('对于每一条标问，随机挑选一条样本作为dev集合'.center(60, '-'))\n",
    "xdf['dev_rnd'] = xdf.q_std.apply(lambda v: np.random.rand())\n",
    "xdf['nrank_dev'] = xdf.groupby('q_std')['dev_rnd'].rank(ascending=0, method='first')\n",
    "q_sim_choose_dev = xdf[xdf.nrank_dev <= 1].drop_duplicates(['q_sim']).q_sim.tolist()\n",
    "df_train = df_final.copy(deep=True)\n",
    "df_dev = df_final[df_final.q_sim.isin(q_sim_choose_dev)]\n",
    "print('第二阶段train集: ', sec_train_file, ', shape: ', df_train.shape)\n",
    "df_train[['label', 'q_std', 'q_sim']].to_csv(sec_train_file, sep=\"\\t\", index=None, header=False)\n",
    "print('第二阶段dev集: ', sec_dev_file, ', shape', df_dev.shape)\n",
    "df_dev[['label', 'q_std', 'q_sim']].to_csv(sec_test_file, sep=\"\\t\", index=None, header=False)\n",
    "df_dev[['label', 'q_std', 'q_sim']].to_csv(sec_dev_file, sep=\"\\t\", index=None, header=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.8 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e42634819b8c191a5d07eaf23810ff32516dd8d3875f28ec3e488928fbd3c187"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
